"""
Phase 4: Threshold and Spline Estimation
==========================================
Test for nonlinearities in the demographic effect:

1. OADR splines: Does the Z₁ effect change above/below old-age dependency
   thresholds of 15%, 20%, 25%?

2. Income nonlinearity: Compare Z₁ effect across income terciles using
   a saturated model (low + high interactions).

3. KAOPEN nonlinearity: Continuous interaction vs saturated dummy.

4. Combined threshold model: OADR spline (20% threshold) + high-income + OECD interactions.
"""

import pandas as pd
import numpy as np
from pathlib import Path
import sys
import warnings
warnings.filterwarnings('ignore')

# Directory two levels above this file (the parent of the folder holding it).
PROJECT_DIR = Path(__file__).resolve().parent.parent
# Parent of PROJECT_DIR; the estimator module is looked up beneath it.
ROOT_DIR = PROJECT_DIR.parent
# Must run before the import below: `model` is only importable once
# ROOT_DIR/multilateral/src has been put on sys.path.
sys.path.insert(0, str(ROOT_DIR / "multilateral" / "src"))
from model import PanelGLS

# Input directory (main() reads unified_panel.csv from here).
DATA = PROJECT_DIR / "data" / "processed"
# Destination directory for the generated markdown tables.
OUT_TABLES = PROJECT_DIR / "output" / "tables"

# Regressors of the baseline model; every other specification adds
# interaction columns named "<Z var>_x_<suffix>" on top of these.
Z_VARS = ['Z_1', 'Z_2', 'Z_3']

# Dependent variables; each one present in the panel is estimated in turn.
DVS = ['ca_gdp', 'gross_savings_gdp', 'gross_investment_gdp', 'nfa_gdp']


def stars(p):
    """Return the significance marker for p-value *p*.

    ``'***'`` below 0.01, ``'**'`` below 0.05, ``'*'`` below 0.1, and the
    empty string otherwise.
    """
    # Cutoffs are listed from strictest to loosest so the first match wins.
    for cutoff, marker in ((0.01, '***'), (0.05, '**'), (0.1, '*')):
        if p < cutoff:
            return marker
    return ''


def fmt(val, p):
    """Format *val* with three decimals, followed by the stars for *p*."""
    number = format(val, '.3f')
    return number + stars(p)


def run_gls(df, y_var, x_vars, *, min_obs=50):
    """Fit a PanelGLS regression of ``y_var`` on ``x_vars``.

    Rows with a missing value in ``y_var``, any of ``x_vars``, ``iso3`` or
    ``year`` are dropped before estimation.

    Args:
        df: Panel containing ``y_var``, ``x_vars``, ``iso3`` and ``year``.
        y_var: Name of the dependent-variable column.
        x_vars: Sequence of regressor column names, in estimation order.
        min_obs: Minimum number of complete rows required to estimate.

    Returns:
        ``None`` when fewer than ``min_obs`` complete rows remain or the fit
        raises; otherwise a dict with ``n_obs``, ``n_countries``,
        ``r_squared`` and, for each regressor ``v``, ``v_coef``, ``v_se``
        and ``v_p``.

    Raises:
        KeyError: If a requested column is missing from ``df``.
    """
    x_vars = list(x_vars)  # accept any sequence, not only lists
    sub = df[[y_var, *x_vars, 'iso3', 'year']].dropna()
    if len(sub) < min_obs:
        return None
    gls = PanelGLS()
    try:
        gls.fit(sub[y_var].values, sub[x_vars].values,
                sub['iso3'].values, sub['year'].values)
    except Exception as exc:
        # Still best-effort (the model is skipped), but say so: a silently
        # missing row in the output tables is otherwise impossible to trace.
        # stderr is used because the module suppresses warnings globally.
        print(f"  [run_gls] {y_var}: PanelGLS fit failed "
              f"({type(exc).__name__}: {exc}); model skipped",
              file=sys.stderr)
        return None
    result = {'n_obs': gls.n_obs, 'n_countries': gls.n_countries,
              'r_squared': gls.r_squared}
    # Estimates are read positionally, in the order of x_vars.
    for i, var in enumerate(x_vars):
        result[f'{var}_coef'] = gls.beta[i]
        result[f'{var}_se'] = gls.se[i]
        result[f'{var}_p'] = gls.pvalues[i]
    return result


# Interaction columns of the income-saturated model (low + high groups).
_INCOME_INT = (['Z_1_x_low', 'Z_2_x_low', 'Z_3_x_low'] +
               ['Z_1_x_high', 'Z_2_x_high', 'Z_3_x_high'])

# Interaction columns of the combined model: OADR@20 + high income + OECD.
_COMBINED_INT = (['Z_1_x_oadr20', 'Z_2_x_oadr20', 'Z_3_x_oadr20'] +
                 ['Z_1_x_high', 'Z_2_x_high', 'Z_3_x_high'] +
                 ['Z_1_x_oecd', 'Z_2_x_oecd', 'Z_3_x_oecd'])


def _summary_row(dv, label, fit, key_int='--'):
    """Build one row of the threshold summary table from a run_gls result."""
    return {'dv': dv, 'model': label,
            'z1': fmt(fit['Z_1_coef'], fit['Z_1_p']),
            'key_int': key_int,
            'n': int(fit['n_obs']), 'r2': fit['r_squared']}


def _estimate_dv(df, dv):
    """Fit every Phase 4 specification for one dependent variable.

    Returns ``(rows, combined_fit)``: the summary-table rows for ``dv`` and
    the run_gls result of the combined model (``None`` if it was not
    estimated). Both are empty/None when the baseline cannot be estimated.
    """
    # ── 1. Baseline ──────────────────────────────────────────────────
    baseline = run_gls(df, dv, Z_VARS)
    if baseline is None:
        return [], None
    rows = [_summary_row(dv, 'Baseline', baseline)]

    # ── 2. OADR spline models ────────────────────────────────────────
    for thresh in ('15', '20', '25'):
        suffix = f'oadr{thresh}'
        int_vars = [f'Z_{i}_x_{suffix}' for i in (1, 2, 3)]
        if any(v not in df.columns for v in int_vars):
            continue
        m = run_gls(df, dv, Z_VARS + int_vars)
        if m is None:
            continue
        z1x = m[f'Z_1_x_{suffix}_coef']
        z1x_p = m[f'Z_1_x_{suffix}_p']
        print(f"  OADR>{thresh}%: Z₁×spline={fmt(z1x, z1x_p)}, R²={m['r_squared']:.3f}")
        # A missing coefficient is shown as '--' rather than 'nan'.
        key_int = '--' if pd.isna(z1x) else fmt(z1x, z1x_p)
        rows.append(_summary_row(dv, f'OADR>{thresh}%', m, key_int))

    # ── 3. Income saturated model ────────────────────────────────────
    if all(v in df.columns for v in _INCOME_INT):
        m = run_gls(df, dv, Z_VARS + _INCOME_INT)
        if m is not None:
            z1_low = m['Z_1_x_low_coef']
            z1_low_p = m['Z_1_x_low_p']
            z1_high = m['Z_1_x_high_coef']
            z1_high_p = m['Z_1_x_high_p']
            base = m['Z_1_coef']
            print(f"  Income sat: base={fmt(base, m['Z_1_p'])}, "
                  f"×low={fmt(z1_low, z1_low_p)}, ×high={fmt(z1_high, z1_high_p)}")
            print(f"    → Low-income total: {base + z1_low:.3f}, "
                  f"Mid-income total: {base:.3f}, "
                  f"High-income total: {base + z1_high:.3f}")
            # Same precision and significance stars as every other cell
            # (one decimal place would round small coefficients to 0.0).
            key_int = (f"low:{fmt(z1_low, z1_low_p)}, "
                       f"high:{fmt(z1_high, z1_high_p)}")
            rows.append(_summary_row(dv, 'Income saturated', m, key_int))

    # ── 4. Combined threshold: OADR@20 + income + OECD ───────────────
    combined = None
    if all(v in df.columns for v in _COMBINED_INT):
        combined = run_gls(df, dv, Z_VARS + _COMBINED_INT)
        if combined is not None:
            m = combined
            print(f"  Combined: Z₁×OADR20={fmt(m['Z_1_x_oadr20_coef'], m['Z_1_x_oadr20_p'])}, "
                  f"Z₁×High={fmt(m['Z_1_x_high_coef'], m['Z_1_x_high_p'])}, "
                  f"Z₁×OECD={fmt(m['Z_1_x_oecd_coef'], m['Z_1_x_oecd_p'])}")
            rows.append(_summary_row(dv, 'Combined', m, 'see table'))

    return rows, combined


def _write_threshold_table(path, rows):
    """Write the per-DV summary of all threshold models as markdown."""
    # Explicit UTF-8: the table contains non-ASCII characters (Z₁, R², ×)
    # that the platform default encoding may not be able to represent.
    with open(path, 'w', encoding='utf-8') as f:
        f.write("# Phase 4: Threshold and Spline Estimates\n\n")
        for dv in DVS:
            dv_rows = [r for r in rows if r['dv'] == dv]
            if not dv_rows:
                continue
            f.write(f"\n## {dv}\n\n")
            f.write("| Model | Z₁ (base) | Key Interaction | N | R² |\n")
            f.write("|---|---|---|---|---|\n")
            for row in dv_rows:
                f.write(f"| {row['model']} | {row['z1']} | {row['key_int']} | "
                        f"{row['n']} | {row['r2']:.3f} |\n")
        f.write("\n*OADR splines: Z₁×I(OADR>threshold). "
                "Positive = stronger effect above threshold.*\n")
        f.write("*PanelGLS with AR(1). \\*p<0.1, \\*\\*p<0.05, \\*\\*\\*p<0.01*\n")


def _write_combined_table(path, combined_fits):
    """Write the detailed combined-model table, one row per estimated DV."""
    with open(path, 'w', encoding='utf-8') as f:
        f.write("# Phase 4: Combined Threshold Model (OADR@20 + Income + OECD)\n\n")
        f.write("| DV | Z₁ | Z₁×OADR20 | Z₁×High | Z₁×OECD | N | R² |\n")
        f.write("|---|---|---|---|---|---|---|\n")
        for dv, m in combined_fits.items():
            f.write(f"| {dv} | {fmt(m['Z_1_coef'], m['Z_1_p'])} | "
                    f"{fmt(m['Z_1_x_oadr20_coef'], m['Z_1_x_oadr20_p'])} | "
                    f"{fmt(m['Z_1_x_high_coef'], m['Z_1_x_high_p'])} | "
                    f"{fmt(m['Z_1_x_oecd_coef'], m['Z_1_x_oecd_p'])} | "
                    f"{m['n_obs']} | {m['r_squared']:.3f} |\n")
        f.write("\n*Combined model tests whether OADR aging threshold, income level, "
                "and institutional development independently moderate Z₁.*\n")


def main():
    """Estimate all Phase 4 models and write the two markdown tables.

    Reads ``unified_panel.csv`` from DATA, fits the baseline, OADR spline,
    income-saturated and combined models for every dependent variable in
    DVS that the panel contains, prints a progress summary, and writes
    ``phase4_thresholds.md`` and ``phase4_combined_model.md`` to OUT_TABLES.
    """
    print("Phase 4: Threshold and Spline Estimation")
    print("=" * 70)

    df = pd.read_csv(DATA / "unified_panel.csv")
    print(f"Panel: {len(df)} obs\n")

    all_rows = []
    # Combined fits are kept so the detailed table reuses them instead of
    # re-estimating each model; insertion order follows DVS.
    combined_fits = {}

    for dv in DVS:
        if dv not in df.columns:
            continue
        print(f"\nDV: {dv}")
        print("-" * 50)
        rows, combined = _estimate_dv(df, dv)
        all_rows.extend(rows)
        if combined is not None:
            combined_fits[dv] = combined

    # ── Write output ──────────────────────────────────────────────────
    print(f"\n{'=' * 70}")
    print("Writing output tables...")

    # A fresh checkout may not have the output directory yet.
    OUT_TABLES.mkdir(parents=True, exist_ok=True)

    _write_threshold_table(OUT_TABLES / "phase4_thresholds.md", all_rows)
    print("  Wrote: phase4_thresholds.md")

    # ── Detailed combined model table ─────────────────────────────────
    _write_combined_table(OUT_TABLES / "phase4_combined_model.md", combined_fits)
    print("  Wrote: phase4_combined_model.md")

    print("\nPhase 4 complete.")


if __name__ == '__main__':
    main()
